{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LLaMA ChatInterface Gradio Demo\n",
    "\n",
    "You must have Llama model checkpoints to use this notebook. Please see the [LLaMA](https://github.com/meta-llama/llama-models) documentation, then request access to and download the model checkpoints from Hugging Face."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install transformers --upgrade\n",
    "%pip install gradio --upgrade\n",
    "%pip install accelerate --upgrade"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# no nvlink\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\" \n",
    "# use a specific GPU\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "import transformers\n",
    "import gradio as gr\n",
    "import torch\n",
    "\n",
    "# Path to the local model checkpoint that is actually loaded.\n",
    "# Alternative checkpoint: \"/data/llm/llama/Meta-Llama-3.1-8B-Instruct\"\n",
    "model = \"/data/llm/llama/Llama-3.2-1B-Instruct\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model)\n",
    "\n",
    "# Text-generation pipeline with fp16 weights; device_map=\"auto\" leaves\n",
    "# device placement to the library.\n",
    "pipeline = transformers.pipeline(\n",
    "    \"text-generation\",\n",
    "    model=model,\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map=\"auto\",\n",
    ")\n",
    "\n",
    "\n",
    "def echo(message, history, max_new_tokens=100):\n",
    "    print(\"History: \", history)\n",
    "    # pre-pend the history to message\n",
    "    past_messages = \"\"\n",
    "    if history:\n",
    "        past_messages = history[0][-1]\n",
    "    message = past_messages + \" \" + message\n",
    "    print(\"Message\", message)\n",
    "    sequences = pipeline(\n",
    "        message,\n",
    "        do_sample=True,\n",
    "        top_k=1,\n",
    "        num_return_sequences=1,\n",
    "        eos_token_id=tokenizer.eos_token_id,\n",
    "        truncation = False,\n",
    "        max_new_tokens=max_new_tokens,\n",
    "    )\n",
    "\n",
    "    answer = sequences[0]['generated_text']\n",
    "    answer = answer[len(message):]\n",
    "    # find \"\\n\\n\"\n",
    "    # remove string after the index\n",
    "    index = answer.find(\"\\n\\n\")\n",
    "    if index != -1:\n",
    "        answer = answer[:index]\n",
    "    \n",
    "    return answer\n",
    "\n",
    "\n",
    "# Use the checkpoint directory name as the title; tolerate a trailing slash.\n",
    "title = model.rstrip(\"/\").split(\"/\")[-1] + \" Chat\"\n",
    "# Keep the interface object itself in `demo` (not the value returned by\n",
    "# launch()) so the server can be stopped later with demo.close().\n",
    "demo = gr.ChatInterface(\n",
    "    fn=echo,\n",
    "    examples=[\"What are the planets in the solar system?\", \"Translate this to Tagalog: Hello, how are you?\"],\n",
    "    title=title,\n",
    ")\n",
    "# Trailing semicolon keeps the cell from echoing launch()'s return value.\n",
    "demo.launch(inbrowser=True);\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mspeech",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
